from Network.network import Network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy
from Network.network_utils import reduce_function, get_acti
from Network.General.Factor.Pair.pair import PairNetwork
from Network.Dists.mask_utils import expand_mask
from Network.General.Factor.factor_utils import init_decode
from Network.General.Factor.factored import return_values

class KeyPairNetwork(Network):
    '''
    Pair network generalized to multiple keys.

    The inputs are split into a set of keys and a set of queries. Each key is
    compared with all of the queries by running the shared ``PairNetwork`` once
    per key, so unlike a plain pairnet this does not assume there is only 1 key.

    Return settings (shape of ``x`` handed to ``return_values``):
        aggregate_final: per-key embeddings are reduced over the key axis,
            flattened and decoded -> (batch, output dim)
        no aggregate_final, no query_aggregate: one output per (key, query)
            pair -> (batch, keys, queries, output dim)
        no aggregate_final, query_aggregate: queries are aggregated inside the
            pair network, giving one output per key, flattened
            -> (batch, keys * output dim)
    '''
    def __init__(self, args):
        super().__init__(args)
        self.fp = args.factor
        self.fnp = args.factor_net
        self.embed_dim = args.embed_dim
        self.append_keys = self.fnp.append_keys
        self.append_zero_keys = (not self.fnp.append_keys) and (self.fnp.num_pair_layers > 1)
        layers = []

        # pairnets assume keys/queries are already embedded using key_query
        # args.factor.embed_dim is the embedded dimension
        self.conv_layers = list()  # NOTE(review): not used by this class; kept for compatibility
        pair_args = copy.deepcopy(args)
        kq_dim = args.factor.key_dim + args.factor.query_dim if self.append_keys or self.append_zero_keys else args.factor.query_dim
        pair_args.factor_net.append_zero_keys = self.append_zero_keys
        # with embed_dim > 0 the pair network emits an intermediate embedding of
        # size embed_dim that the decoder below maps to the final output
        pair_args.object_dim = self.embed_dim if self.embed_dim > 0 else kq_dim
        pair_args.num_outputs = args.output_dim if self.embed_dim <= 0 else self.embed_dim
        pair_args.output_dim = args.output_dim if self.embed_dim <= 0 else self.embed_dim
        # query_aggregate makes the pair network reduce over the queries of each key
        pair_args.aggregate_final = self.fp.query_aggregate
        # decoding happens in this class, never inside the pair network
        pair_args.factor_net.no_decode = True
        if self.embed_dim > 0:
            pair_args.activation_final = pair_args.activation
        self.pair_args = pair_args
        self.pair_layer = PairNetwork(pair_args)
        layers.append(self.pair_layer)

        self.aggregate_final = args.aggregate_final
        # Build the decoder whenever forward may call it: when decoding is not
        # disabled, when aggregating over keys, or whenever the pair network
        # outputs an intermediate embedding (embed_dim > 0). The last case used
        # to require query_aggregate as well, which left forward calling a
        # missing self.decode for non-query-aggregated outputs with embed_dim > 0.
        if not self.fnp.no_decode or self.aggregate_final or self.embed_dim > 0:
            # NOTE: mutates the caller's args so init_decode sees the input width
            args.factor.final_embed_dim = self.embed_dim if self.embed_dim > 0 else args.output_dim
            self.decode = init_decode(args)
            layers.append(self.decode)

        self.model = layers
        self.train()
        self.reset_network_parameters()

    def reappend_queries(self, x, xi):
        '''
        Replace the first ``embed_dim`` features of ``x`` with ``xi``.

        Concatenates ``xi`` with ``x[..., embed_dim:]`` along the last axis.
        '''
        return torch.cat([xi, x[..., self.embed_dim:]], dim=-1)

    def forward(self, key, query, mask, ret_settings):
        '''
        Run the pair network for every key against all queries.

        Args:
            key: tensor indexed as [batch, num_keys, ...]; each key is passed
                to the pair network as a length-1 slice key[:, i:i+1].
            query: queries shared by every key; query.shape[1] is used as the
                number of queries.
            mask: indexed like key along axis 1 (assumed
                [batch, num_keys, num_factors] — TODO confirm).
            ret_settings: forwarded to ``return_values``.

        Returns:
            Whatever ``return_values`` builds from the final output ``x``, the
            inputs, the stacked per-key ``embeddings`` and the key ``reduction``
            (None unless aggregate_final). See the class docstring for ``x``.
        '''
        num_keys = key.shape[1]
        # One pair-network pass per key; [0] selects the first element of the
        # pair network's return value (presumably its main output — verify).
        embeddings = torch.stack(
            [self.pair_layer(key[:, i:i + 1], query, mask[:, i:i + 1], list())[0]
             for i in range(num_keys)],
            dim=1)  # stack along the key axis
        x, reduction = embeddings, None
        if self.aggregate_final:
            # combine the per-key outputs with the reduce function over the key axis
            x = reduce_function(self.fnp.reduce_function, x, dim=1)
            # reshape (not view) so a non-contiguous reduce result is also accepted
            x = x.reshape(x.shape[0], -1)
            reduction = x
            x = self.decode(x)
        elif not self.fp.query_aggregate:
            # per-key output is embed_dim * queries wide; split out the query axis
            x = x.reshape(x.shape[0], x.shape[1], query.shape[1], -1)
            if self.embed_dim > 0:
                # decode each key's (batch, queries, embed) slice with the feature
                # axis moved to position -2 for the decoder
                # TODO: 2D conv would make this more efficient
                x = torch.stack(
                    [self.decode(x[:, i].transpose(-1, -2)).transpose(-1, -2)
                     for i in range(num_keys)],
                    dim=1)
        else:
            # queries were aggregated inside the pair network: one vector per key
            if not self.fnp.no_decode or self.embed_dim > 0:
                x = self.decode(x.transpose(-1, -2)).transpose(-1, -2)
            x = x.reshape(x.shape[0], -1)
        return return_values(ret_settings, x, key, query, embeddings, reduction, mask=mask)